package cn.kennylee.codehub.mybatis.das.extension;

import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ReflectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.core.util.TypeUtil;
import cn.hutool.json.JSONUtil;
import cn.kennylee.codehub.common.das.dto.PageDto;
import cn.kennylee.codehub.mybatis.das.eo.BaseEo;
import cn.kennylee.codehub.mybatis.das.utils.DasHelper;
import com.baomidou.mybatisplus.core.enums.SqlMethod;
import com.baomidou.mybatisplus.core.exceptions.MybatisPlusException;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.metadata.TableInfo;
import com.baomidou.mybatisplus.core.metadata.TableInfoHelper;
import com.baomidou.mybatisplus.core.toolkit.Assert;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.Constants;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.baomidou.mybatisplus.extension.toolkit.SqlHelper;
import com.github.yulichang.query.MPJQueryWrapper;
import jakarta.annotation.Resource;
import org.apache.ibatis.binding.MapperMethod;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.mybatis.spring.SqlSessionHolder;
import org.slf4j.Logger;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.transaction.support.TransactionSynchronizationManager;

import java.io.Serializable;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Types;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * <p>Abstract DAS (data access service) that layers business-specific behaviour on top of {@link ComBaseDas}.</p>
 * <p>Provides mapper lookup by naming convention, entity-class resolution, batch insert (a multi-row MySQL
 * {@code INSERT} or the MyBatis batch executor), batch update / save-or-update, and a generic single-table
 * page query.</p>
 * <p>Created on 2024/5/15.</p>
 * <ol>
 *     <li>T: entity type</li>
 *     <li>P: primary-key type</li>
 * </ol>
 *
 * @author kennylee
 * @since 0.0.1
 */
public abstract class AbstractDas<T extends BaseEo<P>, P extends Serializable> extends ComBaseDas<T, P> implements ApplicationContextAware {

    private final Logger logger = org.slf4j.LoggerFactory.getLogger(getClass());

    private ApplicationContext applicationContext;

    /**
     * Injected mapper beans keyed by bean name; searched by {@link #getBaseMapper()}.
     */
    @Resource
    private Map<String, BaseMapper<T>> mappers;

    /**
     * Mapper for the entity type, resolved lazily by {@link #getBaseMapper()}.
     */
    private BaseMapper<T> baseMapper;

    /**
     * Entity class, resolved from the first generic type argument of the concrete subclass.
     */
    protected Class<T> entityClass = currentModelClass();

    /**
     * Returns the mapper bound to the entity type, resolving and caching it on first use.
     * <p>Two bean names are tried in order:</p>
     * <ol>
     *     <li>the part of the entity simple name before its last {@code "Eo"}, first letter lower-cased, plus
     *     {@code "Mapper"} (e.g. {@code UserEo -> userMapper});</li>
     *     <li>the simple name without its last two characters plus {@code "Mapper"}
     *     (e.g. {@code UserEo -> UserMapper}).</li>
     * </ol>
     *
     * @return the mapper, never {@code null}
     * @throws RuntimeException when neither candidate bean name is present in {@link #getMappers()}
     */
    @Override
    public BaseMapper<T> getBaseMapper() {
        if (this.baseMapper != null) {
            return this.baseMapper;
        }

        String tClassName = getEoClassName();
        String firstSearchMapperName = StrUtil.lowerFirst(StrUtil.subBefore(tClassName, "Eo", true)) + "Mapper";
        this.baseMapper = this.getMappers().get(firstSearchMapperName);

        String secondSearchMapperName = null;
        if (this.baseMapper == null) {
            secondSearchMapperName = tClassName.substring(0, tClassName.length() - 2) + "Mapper";
            this.baseMapper = this.getMappers().get(secondSearchMapperName);
        }

        if (this.baseMapper == null) {
            throw new RuntimeException(tClassName + "找不到对应的mapper:" + firstSearchMapperName + "/" + secondSearchMapperName);
        }
        return this.baseMapper;
    }

    /**
     * @return simple name of the entity class
     */
    @NonNull
    private String getEoClassName() {
        return this.getEntityClass().getSimpleName();
    }

    @Override
    public Class<T> getEntityClass() {
        return entityClass;
    }

    /**
     * Resolves the entity class from the first generic type argument of the runtime class.
     * <p>NOTE(review): this runs during field initialisation, so it must not depend on instance state.</p>
     *
     * @return entity class
     */
    @Override
    @SuppressWarnings("unchecked")
    protected Class<T> currentModelClass() {
        return (Class<T>) TypeUtil.getClass(TypeUtil.getTypeArgument(this.getClass()));
    }

    /**
     * @return injected mapper beans keyed by bean name
     */
    protected Map<String, BaseMapper<T>> getMappers() {
        return this.mappers;
    }

    /**
     * Inserts the entities in chunks of {@code batchSize}.
     * <p>When the database supports MySQL syntax, each chunk becomes one multi-row {@code INSERT}
     * ({@link #batchInsertForMySql}). Otherwise the MyBatis batch executor is used ({@link #commonSaveBatch}).</p>
     *
     * @param entityList entities to insert; must not be {@code null} or empty
     * @param batchSize  number of entities per chunk / flush
     * @return {@code true} if at least one row was inserted (MySQL path) or the batch completed (fallback path)
     * @throws IllegalArgumentException if {@code entityList} is {@code null} or empty
     */
    @Override
    public boolean saveBatch(Collection<T> entityList, int batchSize) {
        if (CollectionUtils.isEmpty(entityList)) {
            throw new IllegalArgumentException("插入对象列表不能为空!");
        }

        if (isSupportsMySqlSyntax()) {
            return SqlHelper.retBool(batchInsertForMySql(entityList, batchSize));
        }
        return commonSaveBatch(entityList, batchSize);
    }

    /**
     * <p>Builds and executes MySQL multi-row {@code INSERT} statements, one per chunk of {@code batchSize}
     * entities.</p>
     * <p>Each statement is parsed once, which greatly reduces network and parsing overhead for large inserts.
     * Statement size is bounded by {@code max_allowed_packet}, so keep {@code batchSize} moderate.</p>
     * <p>Before splitting, every {@code BatchSavePropsProvider} bean may post-process the entities.</p>
     *
     * @param entities  entities to insert; keep the size reasonable
     * @param batchSize entities per generated statement; must be positive
     * @return total number of affected rows reported by the driver
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    protected int batchInsertForMySql(@NonNull Collection<T> entities, int batchSize) {
        if (CollUtil.isEmpty(entities)) {
            return 0;
        }
        // A non-positive size would otherwise yield an empty chunk and therefore an invalid "VALUES " clause.
        if (batchSize < 1) {
            throw new IllegalArgumentException("batchSize 必须大于 0: " + batchSize);
        }

        // Let registered providers post-process entity properties before the SQL is built.
        final TableInfo tableInfo = TableInfoHelper.getTableInfo(getEntityClass());
        @SuppressWarnings("rawtypes")
        ObjectProvider<BatchSavePropsProvider> batchSavePropsProviders = applicationContext.getBeanProvider(BatchSavePropsProvider.class);
        batchSavePropsProviders.stream().forEach(provider -> provider.provide(entities, tableInfo));

        final List<List<T>> splitEntityList = CollUtil.split(entities, batchSize);

        SqlSessionFactory sqlSessionFactory = SqlHelper.sqlSessionFactory(getEntityClass());

        SqlSessionHolder sqlSessionHolder = (SqlSessionHolder) TransactionSynchronizationManager.getResource(sqlSessionFactory);
        boolean transaction = TransactionSynchronizationManager.isSynchronizationActive();
        if (sqlSessionHolder != null) {
            SqlSession sqlSession = sqlSessionHolder.getSqlSession();
            // MyBatis cannot switch executors within one session. If a transaction-bound session already exists,
            // flush/commit it first so its pending work precedes ours. Inside a transaction this is commit(false).
            sqlSession.commit(!transaction);
        }

        // key: entity property name, value: column name
        final Map<String, String> columnPropNameMap = getColumnPropNameMap();

        int insertedCount = 0;
        try (SqlSession sqlSession = sqlSessionFactory.openSession()) {
            // Chunks are processed sequentially: sharing one SqlSession across threads is not safe.
            for (List<T> chunk : splitEntityList) {
                insertedCount += doBatchSave(sqlSession, chunk, columnPropNameMap, transaction);
            }
        }
        return insertedCount;
    }

    /**
     * Executes one multi-row {@code INSERT} for the given chunk.
     * <p>A column that is {@code null} in every entity of the chunk is left out of the statement entirely. A
     * column that is {@code null} only in some entities is bound as SQL {@code NULL} for those rows.</p>
     * <p>NOTE(review): values are bound with {@link PreparedStatement#setObject}, so MyBatis type handlers are
     * not applied here.</p>
     *
     * @param sqlSession        session providing the JDBC connection
     * @param entityList        entities of this chunk
     * @param columnPropNameMap property name to column name mapping of the entity
     * @param transaction       whether a Spring transaction synchronization is active; if not, the chunk is
     *                          committed right away
     * @return number of rows reported by the driver
     * @throws MybatisPlusException on {@link SQLException}
     * @throws RuntimeException     on any other failure
     */
    protected int doBatchSave(SqlSession sqlSession, List<T> entityList, Map<String, String> columnPropNameMap, boolean transaction) {
        // Keep only the columns that have a non-null value in at least one entity; preserve column order.
        final Map<String, String> nonNullColumnPropNameMap = columnPropNameMap.entrySet().stream()
            .filter(entry -> entityList.stream().anyMatch(entity -> Objects.nonNull(ReflectUtil.getFieldValue(entity, entry.getKey()))))
            .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (v1, v2) -> v1, LinkedHashMap::new));

        if (logger.isDebugEnabled()) {
            logger.debug("非空字段名映射: {}", JSONUtil.toJsonStr(nonNullColumnPropNameMap));
        }

        String sql = buildInsertSql(entityList, nonNullColumnPropNameMap);
        logger.debug("插入数据条数: {}, 批量插入SQL: {}", entityList.size(), sql);

        Connection connection = sqlSession.getConnection();
        try (PreparedStatement stmt = connection.prepareStatement(sql)) {
            setBatchParameters(stmt, entityList, List.copyOf(nonNullColumnPropNameMap.keySet()));
            // One statement carries all rows of the chunk, so it is executed directly; JDBC batching is not used.
            int effectRows = stmt.executeUpdate();

            if (!transaction) {
                logger.debug("非事务情况下，强制commit");
                sqlSession.commit(true);
            }
            return effectRows;
        } catch (SQLException e) {
            logger.error("SQL 执行失败: {}", e.getMessage(), e);
            sqlSession.rollback();
            throw new MybatisPlusException("批量插入失败: " + e.getMessage(), e);
        } catch (Exception e) {
            logger.error("未知错误: {}", e.getMessage(), e);
            sqlSession.rollback();
            throw new RuntimeException("批量插入失败", e);
        }
    }

    /**
     * Builds {@code INSERT INTO table (c1, c2, ...) VALUES (?, ?, ...), (?, ?, ...), ...} with one placeholder
     * group per entity.
     * <p>NOTE(review): column names are emitted exactly as provided, without quoting.</p>
     *
     * @param entityList               entities of the chunk; only the size is used
     * @param nonNullColumnPropNameMap property name to column name mapping of the columns to insert
     * @return the SQL text
     */
    private String buildInsertSql(Collection<T> entityList,
                                  Map<String, String> nonNullColumnPropNameMap) {
        String columnNames = StrUtil.join(", ", nonNullColumnPropNameMap.values());

        String singleValuePlaceholders = "(" + StrUtil.join(", ", Collections.nCopies(nonNullColumnPropNameMap.size(), "?")) + ")";
        String allValuesPlaceholders = String.join(", ", Collections.nCopies(entityList.size(), singleValuePlaceholders));

        return StrUtil.format("INSERT INTO {} ({}) VALUES {}", getTableName(), columnNames, allValuesPlaceholders);
    }

    /**
     * Binds all parameters of a multi-row {@code INSERT}, row by row, in the order of
     * {@code nonNullColumnPropNames}.
     *
     * @param preparedStatement      statement produced from {@link #buildInsertSql}
     * @param entities               entities of the chunk, in row order
     * @param nonNullColumnPropNames property names, in the same order as the SQL column list
     * @throws SQLException if binding fails
     */
    protected void setBatchParameters(PreparedStatement preparedStatement,
                                      List<T> entities,
                                      List<String> nonNullColumnPropNames) throws SQLException {
        int paramIndex = 1;
        for (T entity : entities) {
            for (String propName : nonNullColumnPropNames) {
                Object value = ReflectUtil.getFieldValue(entity, propName);
                if (value == null) {
                    preparedStatement.setNull(paramIndex++, Types.NULL);
                } else {
                    preparedStatement.setObject(paramIndex++, value);
                }
            }
        }
    }

    /**
     * Generic batch insert through the MyBatis batch executor, flushing every {@code batchSize} statements.
     *
     * @param entityList entities to insert
     * @param batchSize  number of inserts between flushes
     * @return {@code true} when all statements were flushed
     * @throws RuntimeException wrapping any failure; the parameters are logged as JSON
     */
    protected boolean commonSaveBatch(Collection<T> entityList, int batchSize) {
        try (SqlSession batchSqlSession = this.sqlSessionBatch()) {
            String sqlStatement = this.getSqlStatement(SqlMethod.INSERT_ONE);

            int inserted = 0;
            for (T entity : entityList) {
                batchSqlSession.insert(sqlStatement, entity);
                // Flush after every full batch; the original check flushed first after batchSize + 1 inserts.
                if (++inserted % batchSize == 0) {
                    batchSqlSession.flushStatements();
                }
            }
            // Flush the trailing partial batch (no-op if nothing is pending).
            batchSqlSession.flushStatements();
            return true;
        } catch (RuntimeException e) {
            logger.error(e.getMessage(), e);
            logger.error("批量插入入参: {}", JSONUtil.toJsonStr(entityList));
            throw new RuntimeException("批量插入失败！", e);
        }
    }

    /**
     * Returns the mapped statement id for the given method of this entity's mapper interface.
     *
     * @param sqlMethod method name
     * @return statement id, e.g. {@code com.example.UserMapper.insert}
     * @since 3.4.0
     */
    protected String getSqlStatement(SqlMethod sqlMethod) {
        Class<?> mapper = getMapperClass();
        return mapper.getName() + "." + sqlMethod.getMethod();
    }

    /**
     * Returns the mapper interface behind {@link #getBaseMapper()}. When the bean is a JDK dynamic proxy, the
     * proxied {@link BaseMapper} sub-interface is returned instead of the proxy class.
     *
     * @return mapper interface, or the bean class if it is not a proxy
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    protected Class<? extends BaseMapper> getMapperClass() {
        Class<? extends BaseMapper> clazz = this.getBaseMapper().getClass();
        if (java.lang.reflect.Proxy.isProxyClass(clazz) || clazz.getName().contains("$Proxy")) {
            Class<?>[] interfaces = clazz.getInterfaces();
            for (Class<?> candidate : interfaces) {
                if (BaseMapper.class.isAssignableFrom(candidate)) {
                    return (Class<? extends BaseMapper>) candidate;
                }
            }
            // Fall back to the first proxied interface, as before.
            clazz = (Class<? extends BaseMapper>) interfaces[0];
        }
        return clazz;
    }

    /**
     * Inserts entities without an id or whose id is not found by {@code selectById}; updates the rest by id.
     *
     * @param entityList entities to save or update
     * @param batchSize  batch size passed to {@link SqlHelper#saveOrUpdateBatch}
     * @return result of {@link SqlHelper#saveOrUpdateBatch}
     */
    @Override
    @Transactional(rollbackFor = Exception.class)
    public boolean saveOrUpdateBatch(Collection<T> entityList, int batchSize) {
        TableInfo tableInfo = TableInfoHelper.getTableInfo(getEntityClass());
        Assert.notNull(tableInfo, "error: can not execute. because can not find cache of TableInfo for entity!");
        String keyProperty = tableInfo.getKeyProperty();
        Assert.notEmpty(keyProperty, "error: can not execute. because can not find column for id from entity!");

        // SqlHelper derives the INSERT statement id from this class, so it must be the mapper interface and not
        // the proxy class of the bean.
        Class<?> mapperClass = getMapperClass();
        String selectByIdStatement = getSqlStatement(SqlMethod.SELECT_BY_ID);
        String updateByIdStatement = getSqlStatement(SqlMethod.UPDATE_BY_ID);

        return SqlHelper.saveOrUpdateBatch(this.entityClass, mapperClass, this.log, entityList, batchSize, (sqlSession, entity) -> {
            Object idVal = tableInfo.getPropertyValue(entity, keyProperty);
            return StringUtils.checkValNull(idVal)
                || CollectionUtils.isEmpty(sqlSession.selectList(selectByIdStatement, entity));
        }, (sqlSession, entity) -> {
            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
            param.put(Constants.ENTITY, entity);
            sqlSession.update(updateByIdStatement, param);
        });
    }

    /**
     * Updates the entities by id in batches.
     *
     * @param entityList entities to update; each must have a non-null id
     * @param batchSize  batch size
     * @return result of {@code executeBatch}
     * @throws RuntimeException if an entity has a {@code null} id; raised inside the batch callback, so it
     *                          surfaces however {@code executeBatch} propagates failures
     */
    @Transactional(rollbackFor = Exception.class)
    @Override
    public boolean updateBatchById(Collection<T> entityList, int batchSize) {
        String sqlStatement = getSqlStatement(SqlMethod.UPDATE_BY_ID);
        return executeBatch(entityList, batchSize, (sqlSession, entity) -> {
            if (Objects.isNull(entity.getId())) {
                throw new IllegalArgumentException("Id 不能为空！");
            }

            MapperMethod.ParamMap<T> param = new MapperMethod.ParamMap<>();
            param.put(Constants.ENTITY, entity);
            sqlSession.update(sqlStatement, param);
        });
    }

    /**
     * <p>Runs a page query built from the paging DTO.</p>
     * <p>Suitable for most single-table page queries.</p>
     *
     * @param pageDto paging/query parameters
     * @param <E>     paging DTO type
     * @return page result
     */
    public <E extends PageDto> IPage<T> findPage(@NonNull E pageDto) {
        return findPage(pageDto, null);
    }

    /**
     * <p>Runs a page query built from the paging DTO, optionally customising the wrapper.</p>
     * <p>If the DTO contributes no sort, a base unique sort is added so the paging order is stable.</p>
     *
     * @param pageDto      paging/query parameters
     * @param consumerFunc optional wrapper customisation applied before sorting
     * @param <E>          paging DTO type
     * @return page result
     */
    public <E extends PageDto> IPage<T> findPage(@NonNull E pageDto, @Nullable ExtWrapperConsumerFunc<T> consumerFunc) {
        MPJQueryWrapper<T> queryChainWrapper = DasHelper.buildChainWrapperWithLike(pageDto, entityClass);
        if (Objects.nonNull(consumerFunc)) {
            consumerFunc.accept(queryChainWrapper);
        }
        int sorts = DasHelper.addSorts(queryChainWrapper, pageDto);
        // Default ordering keeps paging stable when no explicit sort is given.
        if (sorts == 0) {
            DasHelper.addBaseUniqueSort(queryChainWrapper);
        }
        return this.page(buildQueryPage(pageDto), queryChainWrapper);
    }

    /**
     * Builds the page request. The total count is queried unless {@code searchCount} is explicitly
     * {@code false}.
     *
     * @param pageDto paging parameters
     * @param <R>     record type of the page
     * @return page request
     */
    private static <R> IPage<R> buildQueryPage(@NonNull PageDto pageDto) {
        boolean searchCount = Objects.isNull(pageDto.getSearchCount()) || pageDto.getSearchCount();
        return Page.of(pageDto.getPageNo(), pageDto.getPageSize(), searchCount);
    }

    @Override
    public void setApplicationContext(@NonNull ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }
}
